## MSB custom network implementation for a spatially embedded neuromodulator RNN

import torch as tch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.spatial as ss
import torch.nn.utils.parametrize as parametrize
from torch import nn, jit
import math

class SpatialWeight(nn.Module):
    """Turn a raw weight tensor into distance-dependent, signed, masked weights.

    ``observable_size`` units are placed at fixed random positions in the unit
    square.  ``forward(W)`` returns::

        scale * (-1) ** inhib * exp(W - Delta) * mask

    where every term has shape ``(observable_size, observable_size, N_nm)``
    and is identical across the last axis:

    * ``Delta`` -- pairwise Euclidean distance between units divided by ``ell``.
    * ``mask``  -- 1 for distinct units closer than ``5 * ell``, else 0.
    * ``inhib`` -- 0/1 flag drawn once per unit and broadcast from the first
      axis; a flag of 1 flips the sign of that unit's whole row.

    Args:
        input_size: Unused; kept so existing call sites keep working.
        observable_size: Number of spatially embedded units.
        ell: Length scale of the exponential distance fall-off.
        N_nm: Size of the trailing axis (one slice per neuromodulator).
    """

    def __init__(self, input_size, observable_size = 64, ell = 0.1, N_nm=4):
        super(SpatialWeight, self).__init__()

        # NOTE: this reseeds NumPy's *global* generator, so the layout is the
        # same on every construction. Kept as-is because surrounding code may
        # rely on the generator state it leaves behind.
        np.random.seed(1)
        positions = np.random.random([observable_size, 2])
        self.pos = nn.Parameter(torch.tensor(positions), requires_grad=False)

        full_shape = (observable_size, observable_size, N_nm)

        # Pairwise distances, replicated along the neuromodulator axis.
        pairwise = ss.distance.cdist(positions, positions)
        self.delpoints = pairwise[:, :, None] * np.ones(full_shape)
        self.ell = ell
        self.scale = 1

        # One 0/1 flag per unit, drawn uniformly (i.e. P(flag = 1) = 0.5).
        # NOTE(review): the original call passed [1 - pinhib, pinhib] as the
        # third positional argument of np.random.choice, which is `replace`,
        # not `p`, so the probabilities were never applied. The uniform draw
        # is kept so the sampled pattern is unchanged; pass `p=` explicitly if
        # a different inhibitory fraction is ever wanted.
        flags = np.random.choice([0, 1], observable_size, replace=True)
        inhib = torch.tensor(flags[:, None, None] * np.ones(full_shape)).float()
        # Registered as a buffer so it follows .to()/.cuda()/.double() together
        # with Delta and mask; non-persistent so state_dict keys do not change.
        self.register_buffer('inhib', inhib, persistent=False)

        # Relative distances and connectivity mask are fixed (never trained).
        self.Delta = nn.Parameter(torch.tensor(self.delpoints / self.ell).float(), requires_grad=False)
        within_range = self.delpoints < 5 * self.ell
        off_diagonal = np.eye(observable_size)[:, :, None] == 0
        self.mask = nn.Parameter(torch.tensor(np.logical_and(within_range, off_diagonal)).float(), requires_grad=False)

    def forward(self, W):
        """Return ``scale * (-1)**inhib * exp(W - Delta) * mask`` for raw weights ``W``."""
        sign = (-1) ** self.inhib
        return self.scale * sign * torch.exp(W - self.Delta) * self.mask


class spatial_nmRNNCell_base(jit.ScriptModule):
    """Parameter container for a neuromodulated RNN cell with spatial weights.

    Creates and initialises the parameters used by the cell subclasses:

    * ``weight_ih``    -- ``(hidden_size, input_size)`` input weights.
    * ``weight_hh``    -- ``(hidden_size, hidden_size, N_NM)`` recurrent
      weights, one slice per neuromodulator.
    * ``weight_h2nm``  -- ``(N_NM, hidden_size)`` hidden-to-neuromodulator weights.
    * ``weight_nm2nm`` -- ``(N_NM, N_NM)`` neuromodulator-to-neuromodulator weights.
    * ``weight0_hh``   -- ``(hidden_size, hidden_size)`` recurrent weights that
      do not depend on the neuromodulators; trainable only when ``keepW0``.
    * ``bias``         -- ``(hidden_size,)`` or ``None``.

    Args:
        N_NM: Number of neuromodulators.
        input_size: Number of input features.
        hidden_size: Number of hidden units.
        nonlinearity: Callable stored as ``self.nonlinearity``.
        bias: Whether to create a bias parameter.
        keepW0: Whether ``weight0_hh`` is trainable and randomly initialised
            (otherwise it is frozen at zero).
    """

    def __init__(self, N_NM, input_size, hidden_size, nonlinearity, bias, keepW0 = False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.N_nm = N_NM
        self.keepW0 = keepW0
        self.g = 10

        # N_NM must be forwarded: SpatialWeight otherwise falls back to its
        # default of 4 slices and its Delta/mask no longer match the
        # (hidden, hidden, N_NM) tensor passed through it below.
        self.spatialNet = SpatialWeight(input_size, observable_size = hidden_size, N_nm = N_NM)

        self.weight_ih = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.weight_hh = nn.Parameter(self.spatialNet(torch.Tensor(hidden_size, hidden_size, N_NM)))
        self.weight_h2nm = nn.Parameter(torch.Tensor(N_NM, hidden_size))
        self.weight_nm2nm = nn.Parameter(torch.Tensor(N_NM, N_NM))
        # Always present so forward passes can use it unconditionally; it is
        # only trained when keepW0 is set.
        self.weight0_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad = bool(keepW0))

        if bias:
            self.bias = nn.Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise every parameter in place."""
        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))
        # NOTE(review): this overwrites, in place, the values produced by
        # self.spatialNet in __init__, so only the shape of that tensor
        # survives -- confirm that is intended.
        nn.init.kaiming_uniform_(self.weight_hh, a=self.g/math.sqrt(self.hidden_size))
        nn.init.sparse_(self.weight_h2nm, 0.1)
        nn.init.zeros_(self.weight_nm2nm)

        if self.keepW0:
            nn.init.kaiming_uniform_(self.weight0_hh, a=math.sqrt(5))
        else:
            nn.init.zeros_(self.weight0_hh)

        if self.bias is not None:
            # Same bound as the input weights: 1 / sqrt(fan_in of weight_ih).
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_ih)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)


            
class s_nmRNNCell(spatial_nmRNNCell_base):  # Euler integration of rate-neuron network dynamics
    """One update step of the neuromodulated RNN.

    The state passed in and returned is the hidden state with the
    neuromodulator (NM) state appended along the last axis.

    Args:
        N_nm: Number of neuromodulators (0 disables the NM pathway).
        input_size: Number of input features.
        hidden_size: Number of hidden units.
        nonlinearity: Callable applied to the summed drive; must be provided
            (the ``None`` default is not callable).
        decay: Fraction of the previous state kept at each step
            (original note: ``torch.exp(-dt/tau)``).
        bias: Whether to add a bias to the hidden drive.
        keepW0: Whether the NM-independent recurrent weights ``weight0_hh``
            are trainable and randomly initialised.
    """

    def __init__(self, N_nm, input_size, hidden_size, nonlinearity = None, decay = 0, bias = True, keepW0 = True):
        # keepW0 has to be forwarded, otherwise the base class silently falls
        # back to its own default and the argument given here is ignored.
        super().__init__(N_nm, input_size, hidden_size, nonlinearity, bias, keepW0 = keepW0)
        self.decay = decay    #  torch.exp( - dt/tau )
        self.N_nm = N_nm

    def forward(self, input, hiddenCombined):
        """Advance the combined state by one step.

        Args:
            input: Input for this step; multiplied by ``weight_ih.t()``.
            hiddenCombined: Tensor indexed with three axes whose last axis
                holds the hidden units followed by the ``N_nm`` NM units
                (the layer in this file passes ``[1, batch, n]``).

        Returns:
            Updated hidden and NM states concatenated along axis 2, or the
            updated hidden state alone when ``N_nm == 0``.
        """
        # Split the NM units off the end of the last axis.
        if self.N_nm > 0:
            hidden = hiddenCombined[:, :, :-self.N_nm]
            nm = hiddenCombined[:, :, -self.N_nm:]
        else:
            hidden = hiddenCombined
            nm = None

        # Accumulate the drive in the same order for every configuration.
        drive = input @ self.weight_ih.t()
        if nm is not None:
            # hidden and nm are three-axis tensors here, so the equation needs
            # three subscripts for both (the two-subscript form raised).
            drive = drive + torch.einsum('tbj, ijk, tbk -> bi', hidden, self.weight_hh, nm)
        drive = drive + hidden @ self.weight0_hh.t()
        if self.bias is not None:
            drive = drive + self.bias
        activity = self.nonlinearity(drive)

        # NM update reads the hidden state from *before* this step's update.
        if nm is not None:
            activity_nm = self.nonlinearity(hidden @ self.weight_h2nm.t() + nm @ self.weight_nm2nm.t())
            nm = self.decay * nm + (1 - self.decay) * activity_nm
        hidden = self.decay * hidden + (1 - self.decay) * activity

        if nm is None:
            # Nothing to append when there are no neuromodulators.
            return hidden
        return torch.cat([hidden, nm], dim = 2)

class s_nmRNNLayer(nn.Module):
    """Unroll an ``s_nmRNNCell`` along the leading (time) axis of the input.

    This behaves very similarly to ``nn.RNN()``, but the state carried between
    steps -- and therefore every output -- has the NM state appended to the
    hidden state along the last axis of the tensor.
    """

    def __init__(self, N_nm, input_size, hidden_size, nonlinearity, decay = 0.9, bias = False, keepW0 = False):
        super().__init__()
        self.N_nm = N_nm
        self.rnncell = s_nmRNNCell(
            N_nm,
            input_size,
            hidden_size,
            nonlinearity = nonlinearity,
            decay = decay,
            bias = bias,
            keepW0 = keepW0,
        )

    def forward(self, input, initH):
        """Run the cell once per time step.

        Args:
            input: Tensor laid out as [time, batch, n_input].
            initH: Initial combined state, laid out as [1, batch, n_rnn].

        Returns:
            A pair ``(outputs, state)``: the per-step states stacked along a
            new leading axis (each with its axis 0 squeezed), and the final
            state exactly as returned by the cell.
        """
        state = initH
        per_step = []
        for frame in input.unbind(0):  # one slice per time step
            state = self.rnncell(frame, state)
            # Like a vanilla RNN, the output at each step is the state itself.
            per_step.append(state.squeeze(0))
        return torch.stack(per_step), state


